#include "config.h"

#include <stdio.h>
#include <math.h>

#define DBG_SUBSYS S_LIBSTORAGE

#include "adt.h"
#include "sysy_lib.h"
#include "ytime.h"
#include "token_bucket.h"

/*
 * Compile-time switch for the dynamic bucket. When true,
 * dynamic_bucket_update() recomputes the level bucket's average and
 * dynamic_bucket_request() delegates to level_bucket_request(); otherwise
 * dynamic_bucket_request() uses __dynamic_bucket_request().
 */
#define DYNAMIC_BUCKET_USE_LEVEL_BUCKET FALSE

/*
 * When non-zero, token_bucket_consume() additionally consults the embedded
 * leaky bucket (for buckets created with include_leaky set).
 */
#define USE_LEAKY_BUCKET 1


/**
 * Initialise a token bucket.
 *
 * The bucket starts full (tokens == capacity) and both timestamps are set to
 * the current time. burst_max is raised to at least rate. When leaky is
 * non-zero the embedded leaky bucket is initialised with
 * (burst_max - rate, rate).
 *
 * @param bucket    bucket to initialise
 * @param name      NUL-terminated label; truncated if longer than bucket->name
 * @param capacity  maximum number of tokens the bucket can hold
 * @param rate      refill rate
 * @param burst_max burst limit; values below rate are replaced by rate
 * @param private   stored in bucket->priv; when non-zero,
 *                  token_bucket_consume() and token_bucket_inc() skip the
 *                  rwlock
 * @param leaky     non-zero to enable the embedded leaky bucket
 * @return always 0
 */
int token_bucket_init(token_bucket_t *bucket, const char *name,
                      uint64_t capacity, uint64_t rate, uint64_t burst_max,
                      int private, int leaky)
{
        /*
         * Bounded copy: never write past bucket->name and always
         * NUL-terminate (the previous strcpy() had no length check).
         * NOTE(review): assumes bucket->name is a char array so that sizeof
         * yields its capacity -- confirm in token_bucket.h.
         */
        snprintf(bucket->name, sizeof(bucket->name), "%s", name);

        bucket->capacity = capacity;
        bucket->tokens = capacity;
        bucket->rate = rate;

        bucket->t1 = ytime_gettime();
        bucket->t2 = bucket->t1;

        /* burst_max must not be lower than rate */
        bucket->burst_max = _max(rate, burst_max);

        sy_rwlock_init(&bucket->lock, "token_bucket.lock");

        bucket->include_leaky = leaky;
        if (bucket->include_leaky) {
                leaky_bucket_init(&bucket->lb, bucket->burst_max - bucket->rate, bucket->rate);
        }

        bucket->priv = private;
        bucket->inited = 1;

        return 0;
}

/**
 * Release the resources of a bucket that was initialised.
 *
 * Does nothing when bucket->inited is zero. Otherwise destroys the rwlock,
 * destroys the embedded leaky bucket if one is enabled, and clears the
 * inited flag.
 */
void token_bucket_destroy(token_bucket_t *bucket)
{
        /* Nothing to tear down for a bucket that was never set up. */
        if (!bucket->inited) {
                return;
        }

        sy_rwlock_destroy(&bucket->lock);

        if (bucket->include_leaky) {
                leaky_bucket_destroy(&bucket->lb);
        }

        bucket->inited = 0;
}

/**
 * Configure a bucket, initialising it on first use.
 *
 * If the bucket is not yet initialised, this forwards all arguments to
 * token_bucket_init(). Otherwise only capacity, rate and burst_max are
 * updated, under the write lock; name, private and leaky are ignored on that
 * path. In both cases an enabled leaky bucket is then reconfigured with
 * (burst_max - rate, rate).
 *
 * @return 0 on success, or the error returned by sy_rwlock_wrlock()
 */
int token_bucket_set(token_bucket_t *bucket, const char *name, uint64_t capacity, uint64_t rate, uint64_t burst_max,
                     int private, int leaky)
{
        int ret;

        if (bucket->inited) {
                /* Already live: only the limits may change. */
                ret = sy_rwlock_wrlock(&bucket->lock);
                if (unlikely(ret)) {
                        GOTO(err_ret, ret);
                }

                /* burst_max is never allowed below rate */
                bucket->burst_max = _max(rate, burst_max);
                bucket->rate = rate;
                bucket->capacity = capacity;

                sy_rwlock_unlock(&bucket->lock);
        } else {
                token_bucket_init(bucket, name, capacity, rate, burst_max, private, leaky);
        }

        /* Keep the embedded leaky bucket in step with the new limits. */
        if (bucket->include_leaky) {
                leaky_bucket_set(&bucket->lb, bucket->burst_max - bucket->rate, bucket->rate);
        }

        return 0;
err_ret:
        return ret;
}

/**
 * Refill the bucket for the time elapsed since the last refill (bucket->t1).
 *
 * Adds rate * (now - t1) / USECONDS_PER_SEC tokens when the bucket is not
 * full, clamps tokens to capacity, and records now as the new t1.
 *
 * The caller's timestamp is used as-is. It was previously overwritten by a
 * second clock read, which made the parameter meaningless and let t1 drift
 * from the timestamp the caller stores in t2.
 *
 * @param bucket bucket to refill; the caller provides any locking
 * @param now    current time as obtained from ytime_gettime()
 * @return always 0
 */
static inline int _token_bucket_update_tokens(token_bucket_t *bucket, suseconds_t now)
{
        double delta = 0.0;

        if (bucket->tokens < bucket->capacity) {
                /* If the system time was changed, delta may be negative; its magnitude is used. */
                delta = 1.0 * bucket->rate * (now - bucket->t1) / USECONDS_PER_SEC;
                bucket->tokens += fabs(delta);
        }

        /* tokens must never exceed capacity */
        if (bucket->tokens > bucket->capacity) {
                bucket->tokens = bucket->capacity;
        }

        /* Small epsilon tolerates floating-point rounding at the upper bound. */
        YASSERT(bucket->tokens >= 0 && bucket->tokens <= bucket->capacity + 1e-6);
        DBUG("timestamp %ju, capacity %0.2f, rate %0.2f, tokens %0.6f, delta %0.6f\n",
              (uintmax_t)now, bucket->capacity, bucket->rate, bucket->tokens, delta);

        bucket->t1 = now;

        return 0;
}

/**
 * Try to take n tokens from the bucket.
 *
 * The bucket is refilled for the elapsed time first. If at least n tokens are
 * available (and, when enabled, the embedded leaky bucket also admits the
 * request) the tokens are removed, t2 is updated and *is_ready is set to 1.
 * Otherwise *is_ready stays 0. When the shortfall is in tokens, *delay (if
 * non-NULL) receives an estimate of the wait until enough tokens have
 * accumulated; it is not written on the other paths.
 *
 * Non-private buckets are protected by the bucket's write lock.
 *
 * @param bucket   bucket to consume from
 * @param n        number of tokens requested
 * @param is_ready out: 1 if the tokens were granted, 0 otherwise
 * @param delay    out, optional: estimated wait when tokens are short
 * @return 0 on success (granted or not), or the error from sy_rwlock_wrlock()
 */
int token_bucket_consume(token_bucket_t *bucket, uint64_t n, int *is_ready, uint64_t *delay)
{
        int ret;

        *is_ready = 0;

        if (unlikely(!bucket->priv)) {
                ret = sy_rwlock_wrlock(&bucket->lock);
                if (unlikely(ret)) {
                        GOTO(err_ret, ret);
                }
        }

        ytime_t now = ytime_gettime();

        // Limit outflow speed
#if 0
        double delta = 1.0 * bucket->burst_max * (now - bucket->t2) / USECONDS_PER_SEC;
        if (unlikely(delta < n)) {
                DBUG("token bucket overflow, wait burst_max:%f\n", bucket->burst_max);
                goto out;
        }
#endif

        _token_bucket_update_tokens(bucket, now);

        /*
         * Compare in floating point. The previous (int) cast of tokens is
         * undefined once tokens exceeds INT_MAX and could make this test
         * succeed for any n. For an integral n and non-negative tokens,
         * n <= tokens is equivalent to n <= floor(tokens).
         */
        if ((double)n <= bucket->tokens) {
                // TODO: additional condition -- limit consumption within the latest sampling period
#if USE_LEAKY_BUCKET
                if (likely(bucket->include_leaky)) {
                        leaky_bucket_take(&bucket->lb, now, is_ready, NULL);
                        if (!*is_ready)
                                goto out;
                }
#endif

                bucket->tokens -= n;

                /* Consumption succeeded: remember when tokens were last taken. */
                bucket->t2 = now;

                *is_ready = 1;
        } else {
                /* Time for the missing tokens to accumulate at the refill rate. */
                uint64_t _delay = (uint64_t)(USECONDS_PER_SEC * (n - bucket->tokens) / bucket->rate);
                DBUG("n %ju tokens %.2f rate %.2f delay %ju\n", (uintmax_t)n, bucket->tokens, bucket->rate, _delay);

                if (delay) {
                        *delay = _delay;
                }
        }

out:
        if (unlikely(!bucket->priv)) {
                sy_rwlock_unlock(&bucket->lock);
        }

        return 0;
err_ret:
        return ret;
}

/**
 * Repeatedly call token_bucket_consume() until the tokens are granted.
 *
 * Sleeps us microseconds (via usleep) after every unsuccessful attempt. When
 * retry_max is positive, the loop gives up once the number of attempts made
 * exceeds retry_max; a non-positive retry_max means no limit. The return
 * value of token_bucket_consume() itself is not inspected.
 *
 * @param tb        bucket to consume from
 * @param n         number of tokens requested
 * @param us        sleep between attempts, passed to usleep()
 * @param retry_max attempt budget, or <= 0 for unlimited
 * @return 0 when the tokens were granted, EAGAIN when the budget ran out
 */
int token_bucket_consume_loop(token_bucket_t *tb, uint64_t n, int us, int retry_max)
{
        int attempts = 0;
        int granted = 0;

        for (;;) {
                if (retry_max > 0 && attempts > retry_max) {
                        DBUG("max %d retry %d ret %d\n", retry_max, attempts, EAGAIN);
                        return EAGAIN;
                }

                attempts++;

                token_bucket_consume(tb, n, &granted, NULL);
                if (granted) {
                        return 0;
                }

                /* Not enough tokens yet: back off before trying again. */
                usleep(us);
        }
}

/**
 * Add n tokens to the bucket, saturating at its capacity.
 *
 * Non-private buckets are updated under the bucket's write lock.
 *
 * @param bucket bucket to top up
 * @param n      number of tokens to add
 * @return 0 on success, or the error returned by sy_rwlock_wrlock()
 */
int token_bucket_inc(token_bucket_t *bucket, uint64_t n)
{
        int ret;

        if (unlikely(!bucket->priv)) {
                ret = sy_rwlock_wrlock(&bucket->lock);
                if (ret) {
                        GOTO(err_ret, ret);
                }
        }

        bucket->tokens += n;

        /* Never hold more than the configured capacity. */
        if (bucket->capacity < bucket->tokens) {
                bucket->tokens = bucket->capacity;
        }

        if (unlikely(!bucket->priv)) {
                sy_rwlock_unlock(&bucket->lock);
        }

        return 0;
err_ret:
        return ret;
}

/**
 * Initialise a level bucket.
 *
 * Applies the limits through level_bucket_set(), sets the initial level to
 * the resulting average, and schedules the first rate and level refresh
 * relative to the time sampled on entry.
 *
 * @param bucket     bucket to initialise
 * @param avg        average rate
 * @param burst_max  burst limit (see level_bucket_set() for defaulting)
 * @param burst_time burst duration (see level_bucket_set() for defaulting)
 */
void level_bucket_init(level_bucket_t *bucket, int avg, int burst_max, int burst_time)
{
        /* Sample the clock before configuring, as both deadlines are based on it. */
        const uint64_t start = ytime_gettime();

        level_bucket_set(bucket, avg, burst_max, burst_time);

        bucket->level = bucket->avg;
        bucket->level_expire = start + BUCKET_LEVEL_UPDATE_INTERVAL;
        bucket->rate_expire = start + USEC_PER_SEC;
}

/**
 * Set the limits of a level bucket.
 *
 * burst_max is used only when it is positive and greater than avg; otherwise
 * the burst limit defaults to avg plus a tenth of avg. A non-positive
 * burst_time defaults to 10. capacity is burst_max * burst_time.
 *
 * @param bucket     bucket to configure
 * @param avg        average rate
 * @param burst_max  requested burst limit
 * @param burst_time requested burst duration
 */
void level_bucket_set(level_bucket_t *bucket, int avg, int burst_max, int burst_time)
{
        const int burst_is_valid = (burst_max > 0 && burst_max > avg);

        bucket->avg = avg;
        bucket->burst_max = burst_is_valid ? burst_max : avg + avg/10;
        bucket->burst_time = (burst_time > 0) ? burst_time : 10;

        bucket->capacity = bucket->burst_max * bucket->burst_time;
}

/**
 * Compute how long a rejected request should be postponed.
 *
 * If the level is negative, the delay is the deficit divided by avg, scaled
 * by USEC_PER_SEC. Otherwise, if the current rate exceeds burst_max, the
 * delay is the time left until rate_expire. Any other state trips YASSERT.
 *
 * The caller provides any locking.
 *
 * @param bucket bucket to inspect
 * @param now    current time
 * @return the delay, or 0 after the assertion on an unexpected state
 */
static int __get_rate_delay_nolock(level_bucket_t *bucket, uint64_t now)
{
        const int in_deficit = (bucket->level < 0);

        if (in_deficit) {
                return ((0 - bucket->level) * USEC_PER_SEC) / bucket->avg;
        }

        if (bucket->rate > bucket->burst_max) {
                /* Over the burst limit: wait for the current rate window to end. */
                return bucket->rate_expire - now;
        }

        /* Callers are expected to invoke this only for one of the cases above. */
        YASSERT(0);
        return 0;
}

/**
 * Roll the rate counter over once its window has expired.
 *
 * When now has reached rate_expire, burst_max is subtracted from the rate
 * (floored at 0) and the next expiry is set USEC_PER_SEC after now.
 * The caller provides any locking.
 */
static void __update_rate_nolock(level_bucket_t *bucket, uint64_t now)
{
        int carry;

        if (now < bucket->rate_expire) {
                return;
        }

        /* Carry over only what exceeded one window's burst allowance. */
        carry = bucket->rate - bucket->burst_max;
        if (carry < 0) {
                carry = 0;
        }

        bucket->rate = carry;
        bucket->rate_expire = now + USEC_PER_SEC;
}

/**
 * Refill the level once its update deadline has passed.
 *
 * If the deadline was missed by at least the computed fill time, the level is
 * reset to capacity. Otherwise avg is added to the level, capped at capacity.
 * The next deadline is set BUCKET_LEVEL_UPDATE_INTERVAL after now.
 * The caller provides any locking.
 */
static void __update_rate_level_nolock(level_bucket_t *bucket, uint64_t now)
{
        uint64_t overdue, fill_time;
        int refilled;

        if (now < bucket->level_expire) {
                return;
        }

        overdue = now - bucket->level_expire;

        /*
         * fill_time is the time needed to fill the bucket (whose size is
         * capacity).
         * NOTE(review): the remainder term is divided by 10 rather than by
         * avg -- confirm this is intended.
         */
        fill_time = (bucket->capacity / bucket->avg) * USEC_PER_SEC +
                    ((bucket->capacity % bucket->avg) * USEC_PER_SEC)/10;

        if (overdue < fill_time) {
                refilled = bucket->level + bucket->avg;
                bucket->level = refilled > bucket->capacity ? bucket->capacity : refilled;
        } else {
                bucket->level = bucket->capacity;
        }

        bucket->level_expire = now + BUCKET_LEVEL_UPDATE_INTERVAL;
}

/**
 * Account for a request of the given size against the level bucket.
 *
 * A return value of ECANCELED means the limit was exceeded; the caller may
 * postpone execution by *delay.
 *
 * The rate and level are refreshed first. The request is always charged
 * (level decreases, rate increases by size), whether or not it is admitted.
 * Admission is decided on the state before charging: level must be positive
 * and rate below burst_max. No locking is done here.
 *
 * @param bucket bucket to charge
 * @param delay  out: suggested postponement, written only on ECANCELED
 * @param size   size of the request
 * @return 0 if admitted, ECANCELED if the limit was exceeded
 */
int level_bucket_request(level_bucket_t *bucket, suseconds_t *delay, long size)
{
        const uint64_t now = ytime_gettime();
        int admitted;

        __update_rate_nolock(bucket, now);
        __update_rate_level_nolock(bucket, now);

        /* Decide on the pre-charge state, then charge unconditionally. */
        admitted = (bucket->level > 0 && bucket->rate < bucket->burst_max);

        bucket->level -= size;
        bucket->rate += size;

        if (admitted) {
                return 0;
        }

        *delay = __get_rate_delay_nolock(bucket, now);
        return ECANCELED;
}

/**
 * Initialise a dynamic bucket.
 *
 * Zeroes the whole structure, initialises its spinlock, and initialises the
 * embedded level bucket with the given average and default burst settings
 * (burst_max and burst_time passed as 0).
 *
 * @param bucket bucket to initialise
 * @param avg    average rate handed to level_bucket_init()
 */
void dynamic_bucket_init(dynamic_bucket_t *bucket, long avg)
{
        /* Start from an all-zero state; the update routines test fields against 0. */
        memset(bucket, 0, sizeof(dynamic_bucket_t));

        sy_spin_init(&bucket->lock);

        level_bucket_init(&bucket->bucket, avg, 0, 0);
}

/**
 * Record in/out token counts in the ring of BUCKET_AVG_UPDATE_COUNT slots,
 * which advances by one slot per unit of now.
 *
 * - First call (bucket->last == 0): start at slot 0 and record begin/last.
 * - now is later than last: advance the cursor by the elapsed amount, zero
 *   every slot skipped over, and store the counts in the new current slot.
 * - Otherwise: accumulate into the current slot.
 *
 * The number of slots cleared is bounded by the ring size, since clearing
 * more would only rewrite the same slots. The previous loop ran once per
 * elapsed unit and truncated the time_t gap to int. The resulting ring
 * contents and cursor are unchanged.
 *
 * The caller provides any locking.
 *
 * @param bucket    bucket whose ring is updated
 * @param now       current time
 * @param intokens  incoming tokens to record
 * @param outtokens outgoing tokens to record
 */
static void __dynamic_bucket_update_seconds(dynamic_bucket_t *bucket, time_t now, long intokens, long outtokens)
{
        int next = 0, stale, i;
        time_t gap;

        if (bucket->last == 0) {
                bucket->cur = 0;
                bucket->intokens[0] = intokens;
                bucket->outtokens[0] = outtokens;
                bucket->last = now;
                bucket->begin = now;
                return;
        }

        if (now <= bucket->last) {
                /* Same time slot (or the clock stepped back): accumulate. */
                bucket->intokens[bucket->cur] += intokens;
                bucket->outtokens[bucket->cur] += outtokens;
                return;
        }

        gap = now - bucket->last;

        /* Once the gap covers the whole ring, every slot is stale. */
        stale = gap < BUCKET_AVG_UPDATE_COUNT ? (int)gap : BUCKET_AVG_UPDATE_COUNT;
        for (i = 1; i <= stale; i++) {
                next = (bucket->cur + i) % BUCKET_AVG_UPDATE_COUNT;
                bucket->intokens[next] = 0;
                bucket->outtokens[next] = 0;
        }

        /* Cursor position after advancing by the full gap, modulo the ring. */
        next = (int)((bucket->cur + gap % BUCKET_AVG_UPDATE_COUNT) % BUCKET_AVG_UPDATE_COUNT);

        YASSERT(next >= 0 && next < BUCKET_AVG_UPDATE_COUNT);
        bucket->cur = next;
        bucket->intokens[next] = intokens;
        bucket->outtokens[next] = outtokens;
        bucket->last = now;
}

/**
 * Record in/out token counts in the hourly ring of BUCKET_HOUR_UPDATE_COUNT
 * slots. The hour index is now / SEC_PER_HOU.
 *
 * - First call (bucket->lasthour == 0): start at slot 0 and record
 *   beginhour/lasthour.
 * - A later hour: zero each slot passed over, move the cursor, and store the
 *   counts in the new current slot.
 * - Otherwise: accumulate into the current slot.
 *
 * The caller provides any locking.
 *
 * @param bucket    bucket whose hourly ring is updated
 * @param now       current time
 * @param intokens  incoming tokens to record
 * @param outtokens outgoing tokens to record
 */
static void __dynamic_bucket_update_hours(dynamic_bucket_t *bucket, time_t now, long intokens, long outtokens)
{
        const time_t this_hour = now / SEC_PER_HOU;
        int slot = 0, elapsed, step;

        if (bucket->lasthour == 0) {
                bucket->curhour = 0;
                bucket->intokens1hour[0] = intokens;
                bucket->outtokens1hour[0] = outtokens;
                bucket->lasthour = this_hour;
                bucket->beginhour = this_hour;
                return;
        }

        if (this_hour <= bucket->lasthour) {
                /* Still within the current hour slot: accumulate. */
                bucket->intokens1hour[bucket->curhour] += intokens;
                bucket->outtokens1hour[bucket->curhour] += outtokens;
                return;
        }

        /* Clear every slot between the old cursor and the new one. */
        elapsed = this_hour - bucket->lasthour;
        step = 1;
        while (step <= elapsed) {
                slot = (bucket->curhour + step) % BUCKET_HOUR_UPDATE_COUNT;
                bucket->intokens1hour[slot] = 0;
                bucket->outtokens1hour[slot] = 0;
                step++;
        }

        YASSERT(slot >= 0 && slot < BUCKET_HOUR_UPDATE_COUNT);
        bucket->curhour = slot;
        bucket->intokens1hour[slot] = intokens;
        bucket->outtokens1hour[slot] = outtokens;
        bucket->lasthour = this_hour;
}

/**
 * Feed observed in/out token counts into the dynamic bucket.
 *
 * Under the bucket's spinlock, the per-second and per-hour rings are updated
 * with the current time from gettime(). When
 * DYNAMIC_BUCKET_USE_LEVEL_BUCKET is enabled, the embedded level bucket's
 * average is also recomputed from the ring totals.
 *
 * @param bucket    bucket to update
 * @param intokens  incoming tokens observed
 * @param outtokens outgoing tokens observed
 * @return 0 on success, or the error returned by sy_spin_lock()
 */
int dynamic_bucket_update(dynamic_bucket_t *bucket, long intokens, long outtokens)
{
        int ret;
        time_t now;

        ret = sy_spin_lock(&bucket->lock);
        if (unlikely(ret)) {
                GOTO(err_ret, ret);
        }

        now = gettime();
        __dynamic_bucket_update_seconds(bucket, now, intokens, outtokens);
        __dynamic_bucket_update_hours(bucket, now, intokens, outtokens);

#if DYNAMIC_BUCKET_USE_LEVEL_BUCKET
        {
                long incount = 0, outcount = 0, avg;
                int i;

                for (i = 0; i < BUCKET_AVG_UPDATE_COUNT; i++) {
                        incount += bucket->intokens[i];
                        outcount += bucket->outtokens[i];
                }

                avg = (incount * USEC_PER_SEC) / BUCKET_AVG_UPDATE_INTERVAL;

                /* Blend 70% of the measured average with the previous one; never 0. */
                avg = (bucket->bucket.avg + avg * 7 / 10) / 2;
                avg = avg == 0 ? 1 : avg;
                level_bucket_set(&bucket->bucket, avg, avg, 1);
        }
#endif

        sy_spin_unlock(&bucket->lock);

        return 0;
err_ret:
        return ret;
}

/**
 * Report average, maximum and minimum of the recorded incoming token counts.
 *
 * Scans the first `samples` slots of the intokens ring, where samples is
 * (last - begin + 1) capped at BUCKET_AVG_UPDATE_COUNT. No locking is done
 * here.
 *
 * @param bucket bucket to read
 * @param _avg   out, required: total divided by the number of slots scanned
 * @param _max   out, optional: largest slot value
 * @param _min   out, optional: smallest slot value
 * @return always 0
 */
int dynamic_bucket_get_intokens(dynamic_bucket_t *bucket, uint64_t *_avg, uint64_t *_max, uint64_t *_min)
{
        uint64_t sum = 0, highest = 0, lowest = UINT64_MAX;
        int samples, idx;

        samples = _min(bucket->last - bucket->begin + 1, BUCKET_AVG_UPDATE_COUNT);

        for (idx = 0; idx < samples; idx++) {
                sum += bucket->intokens[idx];

                if (highest < bucket->intokens[idx]) {
                        highest = bucket->intokens[idx];
                }

                if (lowest > bucket->intokens[idx]) {
                        lowest = bucket->intokens[idx];
                }
        }

        *_avg = sum / samples;

        if (_max) {
                *_max = highest;
        }

        if (_min) {
                *_min = lowest;
        }

        return 0;
}

/**
 * Report average, maximum and minimum of the recorded outgoing token counts.
 *
 * Scans the first `samples` slots of the outtokens ring, where samples is
 * (last - begin + 1) capped at BUCKET_AVG_UPDATE_COUNT. No locking is done
 * here.
 *
 * @param bucket bucket to read
 * @param _avg   out, required: total divided by the number of slots scanned
 * @param _max   out, optional: largest slot value
 * @param _min   out, optional: smallest slot value
 * @return always 0
 */
int dynamic_bucket_get_outtokens(dynamic_bucket_t *bucket, uint64_t *_avg, uint64_t *_max, uint64_t *_min)
{
        uint64_t sum = 0, highest = 0, lowest = UINT64_MAX;
        int samples, idx;

        samples = _min(bucket->last - bucket->begin + 1, BUCKET_AVG_UPDATE_COUNT);

        for (idx = 0; idx < samples; idx++) {
                sum += bucket->outtokens[idx];

                if (highest < bucket->outtokens[idx]) {
                        highest = bucket->outtokens[idx];
                }

                if (lowest > bucket->outtokens[idx]) {
                        lowest = bucket->outtokens[idx];
                }
        }

        *_avg = sum / samples;

        if (_max) {
                *_max = highest;
        }

        if (_min) {
                *_min = lowest;
        }

        return 0;
}

/**
 * Report average, maximum and minimum outgoing tokens from the hourly ring,
 * each divided by SEC_PER_HOU.
 *
 * Scans the first `hours` slots of outtokens1hour, where hours is
 * (lasthour - beginhour + 1) capped at BUCKET_HOUR_UPDATE_COUNT. No locking
 * is done here.
 *
 * @param bucket bucket to read
 * @param _avg   out, required: total / (hours * SEC_PER_HOU)
 * @param _max   out, optional: largest slot value / SEC_PER_HOU
 * @param _min   out, optional: smallest slot value / SEC_PER_HOU
 * @return always 0
 */
int dynamic_bucket_get_outtokens_hour(dynamic_bucket_t *bucket, uint64_t *_avg, uint64_t *_max, uint64_t *_min)
{
        uint64_t sum = 0, highest = 0, lowest = UINT64_MAX;
        int hours, idx;

        hours = _min(bucket->lasthour - bucket->beginhour + 1, BUCKET_HOUR_UPDATE_COUNT);

        for (idx = 0; idx < hours; idx++) {
                sum += bucket->outtokens1hour[idx];

                if (highest < bucket->outtokens1hour[idx]) {
                        highest = bucket->outtokens1hour[idx];
                }

                if (lowest > bucket->outtokens1hour[idx]) {
                        lowest = bucket->outtokens1hour[idx];
                }
        }

        /* Hourly totals are scaled down by the number of seconds in an hour. */
        *_avg = sum / (hours * SEC_PER_HOU);

        if (_max) {
                *_max = highest / SEC_PER_HOU;
        }

        if (_min) {
                *_min = lowest / SEC_PER_HOU;
        }

        return 0;
}

#if !DYNAMIC_BUCKET_USE_LEVEL_BUCKET
/**
 * Decide whether a request must be throttled, based on measured averages.
 *
 * Throttling is considered only when empty < keep. The target throughput is
 * the average incoming rate divided by (keep - empty). If the average
 * outgoing rate exceeds that target, the request is rejected with a delay:
 *
 *   iops2 = 1 / (1/iops1 + delay)
 *   Q     = iops * size
 *   delay = size * (1/Q2 - 1/Q1)
 *
 * capped at three times USEC_PER_SEC. If the formula cannot be applied
 * (target not positive), the cap itself is used as the delay.
 *
 * The caller provides any locking.
 *
 * @param bucket bucket supplying the averages
 * @param delay  out: suggested postponement, written only on ECANCELED
 * @param size   size of the request
 * @param empty  compared against keep to decide whether to throttle
 * @param keep   threshold for empty
 * @return 0 if the request may proceed, ECANCELED if it should be delayed
 */
static int __dynamic_bucket_request(dynamic_bucket_t *bucket, suseconds_t *delay, long size, int empty, int keep)
{
        uint64_t inavg, outavg;
        double last_factor, in_rate, out_rate, q_goal;

        if (empty >= keep) {
                return 0;
        }

        last_factor = keep - empty;

        dynamic_bucket_get_intokens(bucket, &inavg, NULL, NULL);
        dynamic_bucket_get_outtokens(bucket, &outavg, NULL, NULL);

        in_rate = (double)inavg;
        out_rate = (double)outavg;
        q_goal = in_rate / last_factor;

        if (!(out_rate > q_goal)) {
                /* Outflow is within the target: let the request through. */
                DBUG("in %0.2f out %0.2f\n", in_rate, out_rate);
                return 0;
        }

        if (out_rate > 0.0 && q_goal > 0.0) {
                *delay = (int64_t) (USEC_PER_SEC * size * (1.0 / q_goal - 1.0 / (double)outavg));

                DBUG("delay %jd = %u*%ju*(1.0/%0.2f-1.0/%0.2f), limit %0.2f f %0.2f\n",
                                *delay,
                                USEC_PER_SEC,
                                size,
                                q_goal,
                                out_rate,
                                in_rate,
                                last_factor);

                *delay = _min(*delay, USEC_PER_SEC * 3);
        } else {
                /* The formula needs a positive target; fall back to the maximum delay. */
                *delay = USEC_PER_SEC * 3;
                DBUG("delay %jd = %u*%ju*(1.0/%0.2f-1.0/%0.2f), limit %0.2f f %0.2f\n",
                                *delay,
                                USEC_PER_SEC,
                                size,
                                q_goal,
                                out_rate,
                                in_rate,
                                last_factor);
        }

        return ECANCELED;
}
#endif

/**
 * Ask the dynamic bucket whether a request may proceed.
 *
 * Takes the bucket's spinlock and delegates to level_bucket_request() or
 * __dynamic_bucket_request(), depending on DYNAMIC_BUCKET_USE_LEVEL_BUCKET.
 * ECANCELED is an expected outcome and is returned without going through the
 * GOTO error macro.
 *
 * @param bucket bucket to query
 * @param delay  out: suggested postponement when ECANCELED is returned
 * @param size   size of the request
 * @param empty  forwarded to __dynamic_bucket_request(); unused in level-bucket mode
 * @param keep   forwarded to __dynamic_bucket_request(); unused in level-bucket mode
 * @return 0 if the request may proceed, ECANCELED if it should be delayed,
 *         or another error code from the lock or the delegate
 */
int dynamic_bucket_request(dynamic_bucket_t *bucket, suseconds_t *delay, long size, int empty, int keep)
{
        int ret;

        ret = sy_spin_lock(&bucket->lock);
        if (unlikely(ret)) {
                GOTO(err_ret, ret);
        }

#if DYNAMIC_BUCKET_USE_LEVEL_BUCKET
        (void) empty;
        (void) keep;
        ret = level_bucket_request(&bucket->bucket, delay, size);
#else
        ret = __dynamic_bucket_request(bucket, delay, size, empty, keep);
#endif

        if (!ret) {
                sy_spin_unlock(&bucket->lock);
                return 0;
        }

        /* Throttling is routine; only unexpected errors go through GOTO. */
        if (ret != ECANCELED) {
                GOTO(err_lock, ret);
        }

err_lock:
        sy_spin_unlock(&bucket->lock);
err_ret:
        return ret;
}
